#include "palette_bmp.hpp"
#include "pch.hpp"
#include "utils.hpp"

#include <cstdint>
#include <cstring>
#include <limits>

using namespace std;
using namespace cv;

// NOTE: useful references on the BMP file format
// 1. https://blog.csdn.net/u012877472/article/details/50272771
// 2. https://www.cnblogs.com/Matrix_Yao/archive/2009/12/02/1615295.html
// 3. http://www.fysnet.net/bmpfile.htm


// NOTE: BMP stores multi-byte integers in little-endian byte order (the opposite of PNG, which is big-endian).
/// Stores `n` at bytes[*i .. *i + 3], least significant byte first (little endian),
/// then advances the cursor *i by 4.
inline void write_uint32(uchar *bytes, int *i, uint32_t n) {
    uchar *dst = bytes + *i;
    for (int k = 0; k < 4; ++k) {
        dst[k] = static_cast<uchar>((n >> (8 * k)) & 0xFFu);
    }
    *i += 4;
}

/// Stores `n` at bytes[*i] and bytes[*i + 1], low byte first (little endian),
/// then advances the cursor *i by 2.
inline void write_uint16(uchar *bytes, int *i, uint16_t n) {
    uchar *dst = bytes + *i;
    dst[0] = static_cast<uchar>(n & 0xFFu);
    dst[1] = static_cast<uchar>((n >> 8) & 0xFFu);
    *i += 2;
}

inline uint32_t read_uint32(const uchar *bytes, int *i) {
    uint32_t n =
            bytes[*i + 0]
            + (((uint32_t) bytes[*i + 1]) << 8)
            + (((uint32_t) bytes[*i + 2]) << 16)
            + (((uint32_t) bytes[*i + 3]) << 24);
    *i += 4;
    return n;
}

inline uint16_t read_uint16(const uchar *bytes, int *i) {
    uint32_t n =
            bytes[*i + 0]
            + (((uint32_t) bytes[*i + 1]) << 8);
    *i += 2;
    return n;
}

/// Writes the 14-byte BITMAPFILEHEADER starting at bytes[begin_index].
///
/// @param bytes                  destination buffer (14 bytes are written from begin_index)
/// @param begin_index            offset of the header within `bytes`
/// @param file_size              total file size in bytes, stored as bfSize
/// @param real_data_start_index  offset of the pixel data, stored as bfOffBits
/// @return index just past the header (begin_index + 14)
int write_bmp_file_header(uchar *bytes, int begin_index, uint32_t file_size, uint32_t real_data_start_index) {
    int cursor = begin_index;

    // bfType: the "BM" signature.
    bytes[cursor] = 'B';
    bytes[cursor + 1] = 'M';
    cursor += 2;

    write_uint32(bytes, &cursor, file_size);              // bfSize
    write_uint32(bytes, &cursor, 0);                      // bfReserved1 + bfReserved2 (2 bytes each, both zero)
    write_uint32(bytes, &cursor, real_data_start_index);  // bfOffBits

    return cursor;
}

/// Parses and validates the 14-byte BITMAPFILEHEADER starting at bytes[begin_index].
///
/// Checks the "BM" signature and that both reserved fields are zero.
/// bfSize and bfOffBits are skipped without validation: this function does not know
/// the buffer length, and the decoder in this file reads the palette and pixel data
/// sequentially after the headers instead of seeking to bfOffBits.
///
/// @param bytes        source buffer (14 bytes are read from begin_index)
/// @param begin_index  offset of the header within `bytes`
/// @return index just past the header (begin_index + 14)
int read_bmp_file_header(const uchar *bytes, int begin_index) {
    int i = begin_index;

    // bfType
    CV_Assert(bytes[i++] == 'B');
    CV_Assert(bytes[i++] == 'M');

    // bfSize: skipped (see above).
    (void) read_uint32(bytes, &i);

    // bfReserved1 (2 bytes)
    CV_Assert(read_uint16(bytes, &i) == 0);

    // bfReserved2 (2 bytes)
    CV_Assert(read_uint16(bytes, &i) == 0);

    // bfOffBits: skipped (see above).
    (void) read_uint32(bytes, &i);

    return i;
}

/// Writes the 40-byte BITMAPINFOHEADER for an 8-bit palettized, uncompressed image.
///
/// @param bytes        destination buffer (40 bytes are written from begin_index)
/// @param begin_index  offset of the header within `bytes`
/// @param im_index     CV_8UC1 index image; only its type and size are used here
/// @param palette_num  number of palette entries, stored as biClrUsed and
///                     biClrImportant; must be in [1, 256]
/// @return index just past the header (begin_index + 40)
int write_bmp_bitmap_information(uchar *bytes, int begin_index, Mat im_index, int palette_num) {
    // Validate everything up front so nothing is written when a check fails.
    CV_Assert(im_index.type() == CV_8UC1);
    // read_bmp_bitmap_information rejects a zero width or height.
    CV_Assert(im_index.rows > 0 && im_index.cols > 0);
    // 8 bits per pixel can address at most 256 palette entries, and biClrUsed == 0 is
    // interpreted by readers (including read_bmp_bitmap_information) as 256 entries,
    // so an empty palette would yield a header that disagrees with the color table.
    CV_Assert(palette_num >= 1 && palette_num <= 256);

    int i = begin_index;

    // biSize: size of this header structure.
    write_uint32(bytes, &i, 40);

    // biWidth
    write_uint32(bytes, &i, im_index.cols);
    // biHeight: a negative height marks a top-down bitmap (the first stored row is the
    // top row), matching the row order of im_index, so the negated row count is stored.
    write_uint32(bytes, &i, (uint32_t) (-im_index.rows));

    // biPlanes: always 1.
    write_uint16(bytes, &i, 1);

    // biBitCount: one 8-bit palette index per pixel.
    write_uint16(bytes, &i, 8);

    // biCompression: 0 = uncompressed (BI_RGB).
    write_uint32(bytes, &i, 0);
    // biSizeImage: 0 is permitted for uncompressed bitmaps.
    write_uint32(bytes, &i, 0);

    // 3780 pixels per meter is about 96 dpi (72 dpi would be about 2835).
    const int px_per_meter = 3780;
    // biXPelsPerMeter
    write_uint32(bytes, &i, px_per_meter);
    // biYPelsPerMeter
    write_uint32(bytes, &i, px_per_meter);

    // biClrUsed: number of palette entries actually stored in the color table.
    write_uint32(bytes, &i, palette_num);
    // biClrImportant: every entry is marked important.
    write_uint32(bytes, &i, palette_num);

    return i;
}

/// Image parameters parsed from a BITMAPINFOHEADER by read_bmp_bitmap_information.
/// Members default to 0 so a default-initialized instance is never indeterminate.
struct ReadBmpBitmapInfo {
    int width = 0;        // biWidth, in pixels
    int height = 0;       // image height in pixels (the file stores it negated: top-down)
    int palette_num = 0;  // number of palette entries (biClrUsed, with 0 meaning 256)
};

/// Parses and validates the 40-byte BITMAPINFOHEADER starting at bytes[begin_index].
///
/// Only the layout produced by write_bmp_bitmap_information is accepted:
/// top-down rows (negative biHeight), width a multiple of 4 (no row padding),
/// 1 plane, 8 bits per pixel, no compression, biSizeImage == 0, and at most
/// 256 palette entries (biClrUsed == 0 is read as 256).
///
/// @param bytes        source buffer (40 bytes are read from begin_index)
/// @param begin_index  offset of the header within `bytes`
/// @param out_info     receives width, height and palette size
/// @return index just past the header (begin_index + 40)
int read_bmp_bitmap_information(const uchar *bytes, int begin_index, ReadBmpBitmapInfo &out_info) {
    int i = begin_index;

    // biSize: only the 40-byte BITMAPINFOHEADER is supported.
    CV_Assert(40 == read_uint32(bytes, &i));

    // biWidth
    out_info.width = (int) read_uint32(bytes, &i);
    CV_Assert(out_info.width > 0);
    // With 1 byte per pixel, rows need no padding to a 4-byte boundary only when the
    // width is a multiple of 4; row padding is not handled by this reader.
    CV_Assert(out_info.width % 4 == 0);

    // biHeight: must be negative, i.e. top-down row order. INT_MIN is rejected
    // because negating it would overflow.
    const int raw_height = (int) read_uint32(bytes, &i);
    CV_Assert(raw_height < 0 && raw_height != std::numeric_limits<int>::min());
    out_info.height = -raw_height;

    // The pixel data is width * height bytes and is counted with int (read_bmp_data,
    // the int cursor), so the product must fit in int.
    CV_Assert(static_cast<int64_t>(out_info.width) * out_info.height
              <= std::numeric_limits<int>::max());

    // biPlanes
    CV_Assert(1 == read_uint16(bytes, &i));

    // biBitCount: one 8-bit palette index per pixel.
    CV_Assert(8 == read_uint16(bytes, &i));

    // biCompression: 0 (uncompressed).
    CV_Assert(0 == read_uint32(bytes, &i));
    // biSizeImage: the writer in this file always stores 0.
    CV_Assert(0 == read_uint32(bytes, &i));

    // biXPelsPerMeter / biYPelsPerMeter: resolution is not needed for decoding.
    read_uint32(bytes, &i);
    read_uint32(bytes, &i);

    // biClrUsed: 0 means the maximum for the bit depth (2^8 = 256 entries). The value
    // comes straight from the file and sizes the palette that is read next, so reject
    // counts an 8-bit bitmap cannot have (this also rejects values that would become
    // negative when converted to int).
    const uint32_t raw_palette_num = read_uint32(bytes, &i);
    CV_Assert(raw_palette_num <= 256);
    out_info.palette_num = raw_palette_num == 0 ? 256 : (int) raw_palette_num;
    // biClrImportant: ignored.
    read_uint32(bytes, &i);

    return i;
}

/// Writes `palette` as a BMP color table: 4 bytes per entry, in the order
/// blue, green, red, reserved (0). See http://www.fysnet.net/bmpfile.htm
///
/// NOTE(review): channel 2 of each entry is written as blue and channel 0 as red,
/// so this assumes `palette` holds R,G,B channel order (not OpenCV's usual B,G,R);
/// confirm against callers. read_bmp_palette applies the inverse mapping, so an
/// encode/decode round trip preserves the palette either way.
///
/// @param bytes        destination buffer (4 * palette.rows bytes from begin_index)
/// @param begin_index  offset of the color table within `bytes`
/// @param palette      palette.rows x 1 matrix of Vec3b entries
/// @return index just past the color table
int write_bmp_palette(uchar *bytes, int begin_index, Mat palette) {
    uchar *entry = bytes + begin_index;
    for (int row = 0; row < palette.rows; ++row, entry += 4) {
        const Vec3b &color = palette.at<Vec3b>(row, 0);
        entry[0] = color[2];  // blue
        entry[1] = color[1];  // green
        entry[2] = color[0];  // red
        entry[3] = 0;         // reserved
    }
    return begin_index + 4 * palette.rows;
}

int read_bmp_palette(const uchar *bytes, int begin_index, ReadBmpBitmapInfo info, Mat &out_palette) {
    int i = begin_index;

    out_palette = Mat(info.palette_num, 1, CV_8UC3);

    for (int j = 0; j < info.palette_num; ++j) {
        Vec3b &p = out_palette.at<Vec3b>(j, 0);
        // blue, green, red, alpha
        p[2] = bytes[i + 0];
        p[1] = bytes[i + 1];
        p[0] = bytes[i + 2];
        i += 4; // NOTE 不是3，因为有被扔掉的alpha
    }

    return i;
}

/// Copies the pixel indices of `im_index` into bytes[begin_index ...], top row first.
///
/// Each row is exactly im_index.cols bytes and no row padding is written, so the
/// result is only a valid BMP when cols is a multiple of 4 (the caller's check).
/// Non-continuous matrices (e.g. ROIs) are supported by copying one row at a time.
///
/// @param bytes        destination buffer (rows * cols bytes from begin_index)
/// @param begin_index  offset of the pixel data within `bytes`
/// @param im_index     CV_8UC1 index image
/// @return index just past the pixel data (begin_index + rows * cols)
int write_bmp_data(uchar *bytes, int begin_index, const Mat &im_index) {
    // The copies below assume exactly one byte per pixel.
    CV_Assert(im_index.type() == CV_8UC1);

    int i = begin_index;
    const int row_bytes = im_index.cols;

    if (im_index.isContinuous()) {
        // Fast path: the whole image is a single contiguous block.
        const int num_bytes = im_index.rows * row_bytes;
        memcpy(bytes + i, im_index.data, num_bytes);
        i += num_bytes;
    } else {
        // Rows of a non-continuous Mat are separated by a stride, so copy them one by one.
        for (int r = 0; r < im_index.rows; ++r) {
            memcpy(bytes + i, im_index.ptr<uchar>(r), row_bytes);
            i += row_bytes;
        }
    }

    return i;
}

int read_bmp_data(const uchar *bytes, int begin_index, ReadBmpBitmapInfo info, Mat &out_im_index) {
//    Timer t("read_bmp_data");

    int i = begin_index;

    out_im_index = Mat(info.height, info.width, CV_8UC1);
//    t("after create Mat");

    const int num_bytes = out_im_index.rows * out_im_index.cols;
    // https://stackoverflow.com/questions/259297/how-do-you-copy-the-contents-of-an-array-to-a-stdvector-in-c-without-looping
    memcpy(out_im_index.data, &bytes[i], num_bytes);
    i += num_bytes;
//    t("after memcpy");

    return i;
}

/// Encodes an index image plus its palette as an uncompressed, top-down, 8-bit
/// palettized BMP stored in a buffer obtained from create_tom_buffer.
///
/// @param im_index  CV_8UC1 index image; must be non-empty and its width a multiple
///                  of 4 (BMP rows are 4-byte aligned and no padding is written)
/// @param palette   N x 1 CV_8UC3 palette with 1 <= N <= 256
/// @return the encoded file bytes
cv::Mat imencode_palette_bmp(const Mat &im_index, const Mat &palette) {
//    Timer t("imencode_palette_bmp");

    CV_Assert(im_index.type() == CV_8UC1);
    // The decoder rejects a zero width or height, so never produce such a file.
    CV_Assert(im_index.rows > 0 && im_index.cols > 0);
    CV_Assert(palette.cols == 1);
    CV_Assert(palette.type() == CV_8UC3);
    // An 8-bit BMP can index at most 256 colors, and a stored biClrUsed of 0 would be
    // read back as 256 entries, so the palette must have between 1 and 256 rows.
    CV_Assert(palette.rows >= 1 && palette.rows <= 256);

    // BMP rows are 4-byte aligned; requiring a width that is a multiple of 4 avoids
    // having to write zero padding. https://blog.csdn.net/u012877472/article/details/50272771
    CV_Assert(im_index.cols % 4 == 0);

    const int palette_num = palette.rows;
    // 14-byte file header + 40-byte info header + 4 bytes per palette entry.
    const int real_data_start_index = 14 + 40 + 4 * palette_num;
    // Computed in 64 bits because rows * cols can overflow int for large images;
    // the buffer size and write cursors below are int.
    const int64_t file_size64 = static_cast<int64_t>(real_data_start_index)
                                + static_cast<int64_t>(im_index.cols) * im_index.rows;
    CV_Assert(file_size64 <= std::numeric_limits<int>::max());
    const int file_size = static_cast<int>(file_size64);
//    t("prepare");

    Mat bytes = create_tom_buffer(file_size);
//    t("allocate vector");

    int i = 0;
    i = write_bmp_file_header(bytes.data, i, file_size, real_data_start_index);
    i = write_bmp_bitmap_information(bytes.data, i, im_index, palette_num);
    i = write_bmp_palette(bytes.data, i, palette);
    i = write_bmp_data(bytes.data, i, im_index);
    CV_Assert(i == file_size);

    return bytes;
}

/// Decodes a BMP in the layout written by imencode_palette_bmp (8-bit palettized,
/// uncompressed, top-down, width a multiple of 4) into an index image and its palette.
///
/// The buffer length (taken as bytes.rows) is checked against the sizes declared in
/// the headers *before* the palette and pixel data are read. Truncated or inconsistent
/// input is therefore rejected instead of being read past the end of `bytes`.
///
/// @param bytes         encoded file held in a tom buffer
/// @param out_im_index  receives the height x width CV_8UC1 index image
/// @param out_palette   receives the palette_num x 1 CV_8UC3 palette
void imdecode_palette_bmp(const Mat &bytes, Mat &out_im_index, Mat &out_palette) {
//    Timer t("imdecode_palette_bmp");

    CV_Assert(is_tom_buffer(bytes));

    // Both fixed-size headers (14 + 40 bytes) must be present before they are parsed.
    CV_Assert(bytes.rows > 14 + 40);

    ReadBmpBitmapInfo info{};

    int i = 0;
    i = read_bmp_file_header(bytes.data, i);
    i = read_bmp_bitmap_information(bytes.data, i, info);
//    printf("info: w=%d h=%d pnum=%d\n", info.width, info.height, info.palette_num);

    // Header values come from the file and may be bogus: confirm that the declared
    // palette and pixel data exactly fill the rest of the buffer before reading them.
    CV_Assert(info.width > 0 && info.height > 0);
    CV_Assert(info.palette_num >= 1 && info.palette_num <= 256);
    const int64_t expected_size = static_cast<int64_t>(i)
                                  + 4 * static_cast<int64_t>(info.palette_num)
                                  + static_cast<int64_t>(info.width) * info.height;
    CV_Assert(expected_size == bytes.rows);

    i = read_bmp_palette(bytes.data, i, info, out_palette);
    i = read_bmp_data(bytes.data, i, info, out_im_index);
    CV_Assert(i == bytes.rows);
}
